{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "numpy.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pixDvex9KBqt"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "K16pBM8mKK7a",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TfRdquslKbO3"
      },
      "source": [
        "# Load NumPy data with tf.data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-uq3F0ggKlZb"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/load_data/numpy\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/tutorials/load_data/numpy.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/tutorials/load_data/numpy.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/zh-cn/tutorials/load_data/numpy.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GEe3i16tQPjo",
        "colab_type": "text"
      },
      "source": [
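        "Note: The TensorFlow community translated these documents. Because community translations are best-effort, there is no guarantee that they are accurate or that they reflect the latest\n",
        "[official English documentation](https://www.tensorflow.org/?hl=en). If you have suggestions for improving this translation, please submit a pull request to the\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)."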
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-0tqX8qkXZEj"
      },
      "source": [
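        "This tutorial provides an example of loading data from NumPy arrays into a `tf.data.Dataset`.\n",
        "The example loads the MNIST dataset from a `.npz` file. However, the source of the NumPy arrays is not important here."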
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-Ze5IBx9clLB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "D1gtCQrnNk6b",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  # Colab only\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "k6J3JzK5NxQ6",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "G0yWiN8-cpDb"
      },
      "source": [
        "### Load from a `.npz` file"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "GLHNrFM6RWoM",
        "colab": {}
      },
      "source": [
        "DATA_URL = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz'\n",
        "\n",
        "# Download the archive (reused from the local cache on later calls).\n",
        "path = tf.keras.utils.get_file('mnist.npz', DATA_URL)\n",
        "\n",
        "# Read the training and test arrays out of the .npz archive.\n",
        "with np.load(path) as data:\n",
        "  train_examples = data['x_train']\n",
        "  train_labels = data['y_train']\n",
        "  test_examples = data['x_test']\n",
        "  test_labels = data['y_test']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cCeCkvrDgCMM"
      },
      "source": [
        "## Load NumPy arrays with `tf.data.Dataset`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tslB0tJPgB-2"
      },
      "source": [
        "Assuming you have an array of examples and a corresponding array of labels, pass the two arrays as a tuple to `tf.data.Dataset.from_tensor_slices` to create a `tf.data.Dataset`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "QN_8wwc5R7Qm",
        "colab": {}
      },
      "source": [
        "train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6Rco85bbkDfN"
      },
      "source": [
        "## Use the datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0dvl1uUukc4K"
      },
      "source": [
        "### Shuffle and batch the datasets"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "GTXdRMPcSXZj",
        "colab": {}
      },
      "source": [
        "BATCH_SIZE = 64\n",
        "SHUFFLE_BUFFER_SIZE = 100\n",
        "\n",
        "train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "test_dataset = test_dataset.batch(BATCH_SIZE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "w69Jl8k6lilg"
      },
      "source": [
        "### Build and train a model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Uhxr8py4DkDN",
        "colab": {}
      },
      "source": [
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
        "    tf.keras.layers.Dense(128, activation='relu'),\n",
        "    tf.keras.layers.Dense(10, activation='softmax')\n",
        "])\n",
        "\n",
        "model.compile(optimizer=tf.keras.optimizers.RMSprop(),\n",
        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "XLDzlPGgOHBx",
        "colab": {}
      },
      "source": [
        "model.fit(train_dataset, epochs=10)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2q82yN8mmKIE",
        "colab": {}
      },
      "source": [
        "model.evaluate(test_dataset)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
